'''
1、筛选未训练的图片
2、判断是否为有效图片
3、把图片复制到目的目录
4、读取图片的类型、参数
'''
import os

import numpy as np
from keras.applications.imagenet_utils import preprocess_input, decode_predictions
from keras.models import load_model
from keras.preprocessing import image
from scipy.misc import imread, imresize
from keras import backend as K

not_header = ['商品名称', '店铺', '商品毛重', '商品编号', '上市时间', '货号', '商品产地',
              '表盘颜色', '适用年龄', '尺码', '适用年龄', '销售渠道类型', '材质成分', '品牌', '品牌名称',
              '上市年份季节', '年份季节', '配件/备注', '产地', '洗涤说明', '面料成分', '材质', '大码女装分类', '分类', '适用季节',
              '面料', '成分含量', '质地', '尺寸', '服装款式细节', '材质1', '自定义', '材质3', '深灰色', '黑色', '深蓝', '深蓝2.0',
              '温馨提示', '藏青色', '灰色材质', '成分', '材质2', '蓝/黑色材质', '白色材质', '纤维成份', '里料', '厚薄', '厚度',
              '上市年份/季节', '年份/季节', '街头', '甜美', '主要颜色', '规格', '罩杯材质', '中老年女装分类', '穿着方式', '尺码', '品牌',
              '通勤', '流行元素', '中老年风格', '服饰工艺', '女装质地', '适用场景', '网纱', '组合形式', '毛线粗细', '花朵', '流行元素',
              '工艺', '面料分类', '风格', '中老年女装图案', '适用对象', '克重', '组合规格', '里料分类', '主要材质', '图案',
              '填充物', '牛仔面料盎司', '功能', '设计裁剪', '人群', '图案文化', '流行元素/工艺', '流行元素/工艺', '弹力', '2.0材质',
              '适用场合', '重量', '类型', '弹性', '款式', '选购热点', '适用人群', '适合人群', '面料成份'
              ]

# 3 6 15 6 5 6 5
catalog = [93, 94, 95]  # ['上装','裙装','裤装','特色服饰']
sleeve = ['长袖', '短袖', '无袖', '五分袖', '七分袖', '九分袖']  # 袖长
# collar = ['V领', '不规则领', '圆领', '低圆领', '其它', '高领', '半高领', '翻领', '连帽', '围巾领', 'A字型', '立领', '半高圆领',
#           '一字领', '方领', '旗袍领', 'POLO领', '堆堆领', '无领', '椭圆领', '半开领', 'U型领', '八字领', '高圆领', '低领', '常规领',
#           '娃娃领', '鸡心领', '平领', '双层领', '披肩领', '荡领', '橄榄领', '海军领', '大尖领', '系带领', '无', '荷叶领', '挂脖']  # 领型
collar = ['圆领', '其它', 'V领', '高领', '娃娃领', '立领', '方领', '翻领', '海军领', '半高领', 'POLO领', '连帽', '一字领', '半高圆领', '荷叶领']
coat_length = ['超短款', '短款', '常规款', '中长款', '长款', '不对称衣长']  # 衣长
dress_length = ['超短裙', '短裙', '中裙', '中长裙', '长裙']  # 裙长
outseam = ['短裤', '长裤', '热裤', '五分裤', '七分裤', '九分裤']  # 裤长
waistline = ['低腰', '中腰', '高腰', '松紧腰', '调节腰']  # 腰型


def triplet_loss(y_true, y_pred):
    y_pred = K.l2_normalize(y_pred, axis=1)
    batch = 20
    ref1 = y_pred[0:batch, :]
    pos1 = y_pred[batch:batch + batch, :]
    neg1 = y_pred[batch + batch:3 * batch, :]
    dis_pos = K.sum(K.square(ref1 - pos1), axis=1, keepdims=True)
    dis_neg = K.sum(K.square(ref1 - neg1), axis=1, keepdims=True)
    dis_pos = K.sqrt(dis_pos)
    dis_neg = K.sqrt(dis_neg)
    a1 = 0.2
    d1 = K.maximum(0.0, dis_pos - dis_neg + a1)
    return K.mean(d1)


# 加载训练好的模型
def load_models(model_dir):
    useless_model = load_model(os.path.join(model_dir, 'unopen_model.h5'))
    triplet_model = load_model(os.path.join(model_dir, 'triplet_resnet50.h5'))
    triplet_split_model = load_model(os.path.join(model_dir, 'triplet_split_resnet50.h5'))
    return useless_model, triplet_model, triplet_split_model


# 判断商品是否训练过
def is_trained(good_id, trained_dir):
    if good_id in os.listdir(trained_dir):
        return True
    return False


# 判断是否为有效图片
def is_canuse(model, img_data):
    result = model.predict(img_data)
    if np.argmax(result[0]) > 0:
        return False
    return True


def get_img_array(img_path, image_size=224):
    img = image.load_img(img_path, target_size=(image_size, image_size))
    img_data = image.img_to_array(img)
    img_data = np.expand_dims(img_data, axis=0)  # 这三行的意思是提取 图片特征
    img_data = preprocess_input(img_data)
    return img_data


if __name__ == '__main__':
    # img_path = 'D:/cdn1/comment/4519671/5a70200dN99bb9cd3.jpg'
    # # useless_model, triplet_model, triplet_split_model = load_models('E:/cdn/models')
    # useless_model = load_model(os.path.join('E:/cdn/models', 'unopen_model.h5'))
    # img = image.load_img(img_path, target_size=(224, 224))
    # img_data = image.img_to_array(img)
    # img_data = np.expand_dims(img_data, axis=0)  # 这三行的意思是提取 图片特征
    # img_data = preprocess_input(img_data)
    # result = useless_model.predict(img_data)
    # print(result.shape)  # [1,3]
    # print(result[0])
    # print(np.argmax(result[0]))
    a=[0,1,0,0,1]
    b = [0,0,0,1,0]
    pred = np.equal(a,b)
    print(np.mean(pred))
    # src_dir = 'D:/cdn1/comment'
    # trained_dir = 'E:/cdn/datasets/triplet'
    # good_ids = os.listdir(src_dir)
    # for good_id in good_ids:
    #     if not is_trained(good_id,trained_dir):
    #         for name in os.listdir(src_dir,good_id):
    #             if
